package com.alibaba.alink.operator.common.distance;

import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.types.Row;

import com.alibaba.alink.common.exceptions.AkUnsupportedOperationException;
import com.alibaba.alink.common.exceptions.AkPreconditions;
import com.alibaba.alink.common.linalg.BLAS;
import com.alibaba.alink.common.linalg.DenseMatrix;
import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.Vector;

import java.util.Iterator;
import java.util.List;

/**
 * Haversine formula: find the great-circle distance between two points on a sphere.
 * <p>
 * https://en.wikipedia.org/wiki/Haversine_formula.
 * <p>
 * haversine(theta) = sin(theta / 2)^2
 * <p>
 * h = haversine(lat1 - lat2) + cos(lat1) * cos(lat2) * haversine(lon1 - lon2)
 * <p>
 * d = 2 * EARTH_RADIUS * arcsin(sqrt(h))
 * <p>
 * Latitude and longitude are in degree format. If input is a vector, it's required the first value is latitude and the
 * second value is longitude.
 */
public class HaversineDistance extends FastDistance {
	private static final long serialVersionUID = -7525846623969219548L;

	/**
	 * Earth radius. The returned distances are expressed in the same unit as this value
	 * (presumably kilometers, 6371 being the commonly quoted mean Earth radius in km).
	 */
	private static final int EARTH_RADIUS = 6371;

	/**
	 * Degree to radian constant.
	 */
	private static final double DEGREE_TO_RADIAN_CONSTANT = Math.PI / 180;

	/**
	 * Label size: number of pre-computed values stored per point.
	 */
	private static final int LABEL_SIZE = 3;

	/**
	 * Vector size: every point is exactly [latitude, longitude].
	 */
	private static final int VECTOR_SIZE = 2;

	/**
	 * Message used whenever an input does not have exactly {@link #VECTOR_SIZE} components.
	 */
	private static final String VECTOR_SIZE_MESSAGE =
		"HaversineDistance only supports vector size 2, the first value is latitude and the second value is "
			+ "longitude";

	/**
	 * Message used by every entry point that would require sparse data.
	 */
	private static final String SPARSE_NOT_SUPPORTED_MESSAGE = "HaversineDistance not support sparse data!";

	/**
	 * haversine(theta) = sin(theta / 2)^2 = (1 - cos(theta)) / 2
	 * <p>
	 * The sin form is evaluated because (1 - cos(theta)) loses all significant digits for very small angles
	 * (cos(theta) rounds to exactly 1.0), which would make close-by points collapse to distance 0.
	 *
	 * @param theta angle in radians.
	 * @return haversine of the angle, always in [0, 1].
	 */
	static double haverSine(double theta) {
		double halfAngleSin = Math.sin(theta / 2);
		return halfAngleSin * halfAngleSin;
	}

	/**
	 * Transform degree to radian.
	 *
	 * @param data degree.
	 * @return radian.
	 */
	static double degreeToRadian(double data) {
		return data * DEGREE_TO_RADIAN_CONSTANT;
	}

	/**
	 * d = 2 * EARTH_RADIUS * arcsin(sqrt(h))
	 *
	 * @param h haversine of the central angle between the two points.
	 * @return distance on the sphere.
	 */
	private static double cal(double h) {
		// abs() and min() keep the asin argument inside [0, 1] even when rounding pushes h
		// marginally below 0 or above 1 (asin would return NaN outside that range).
		double root = Math.min(1.0, Math.sqrt(Math.abs(h)));
		return 2 * EARTH_RADIUS * Math.asin(root);
	}

	/**
	 * Build the label of one point: [sin(lat), cos(lat)cos(lon), cos(lat)sin(lon)].
	 *
	 * @param latitude  latitude in degrees.
	 * @param longitude longitude in degrees.
	 * @return array of length 3 holding the label.
	 */
	private static double[] vectorLabel(double latitude, double longitude) {
		double lat = degreeToRadian(latitude);
		double lon = degreeToRadian(longitude);
		double latCos = Math.cos(lat);

		return new double[] {Math.sin(lat), latCos * Math.cos(lon), latCos * Math.sin(lon)};
	}

	/**
	 * The shortest distance between two points on the surface of the Earth
	 *
	 * @param latitude1  latitude of the first point, in degrees.
	 * @param longitude1 longitude of the first point, in degrees.
	 * @param latitude2  latitude of the second point, in degrees.
	 * @param longitude2 longitude of the second point, in degrees.
	 * @return distance.
	 */
	public double calc(double latitude1, double longitude1, double latitude2, double longitude2) {
		double lat1 = degreeToRadian(latitude1);
		double lat2 = degreeToRadian(latitude2);
		double deltaLat = lat2 - lat1;
		double deltaLon = degreeToRadian(longitude2) - degreeToRadian(longitude1);

		double h = haverSine(deltaLat) + Math.cos(lat1) * Math.cos(lat2) * haverSine(deltaLon);
		return cal(h);
	}

	/**
	 * The shortest distance between two points on the surface of the Earth
	 *
	 * @param array1 [latitude, longitude] of the first point.
	 * @param array2 [latitude, longitude] of the second point.
	 * @return distance.
	 */
	@Override
	public double calc(double[] array1, double[] array2) {
		AkPreconditions.checkArgument(array1.length == VECTOR_SIZE && array2.length == VECTOR_SIZE,
			VECTOR_SIZE_MESSAGE);

		return calc(array1[0], array1[1], array2[0], array2[1]);
	}

	/**
	 * The shortest distance between two points on the surface of the Earth
	 *
	 * @param vec1 [latitude, longitude] of the first point.
	 * @param vec2 [latitude, longitude] of the second point.
	 * @return distance.
	 */
	@Override
	public double calc(Vector vec1, Vector vec2) {
		AkPreconditions.checkArgument(vec1.size() == VECTOR_SIZE && vec2.size() == VECTOR_SIZE,
			VECTOR_SIZE_MESSAGE);

		return calc(vec1.get(0), vec1.get(1), vec2.get(0), vec2.get(1));
	}

	/**
	 * Sparse data is not supported by this distance; always throws.
	 */
	@Override
	List <FastDistanceData> prepareSparseMatrixData(Tuple2 <Vector, Row> tuple,
													Iterator <Tuple2 <Vector, Row>> iterator,
													int vectorSize) {
		throw new AkUnsupportedOperationException(SPARSE_NOT_SUPPORTED_MESSAGE);
	}

	/**
	 * For Haversine distance, d = 2 * EARTH_RADIUS * arcsin(sqrt(h)),
	 * <p>
	 * h = haversine(lat1 - lat2) + cos(lat1) * cos(lat2) * haversine(lon1 - lon2)
	 * <p>
	 * h = 0.5 * (1 - sin(lat1)sin(lat2) - cos(lat1)cos(lon1)cos(lat2)cos(lon2) -
	 * cos(lat1)sin(lon1)cos(lat2)sin(lon2))
	 * <p>
	 * h = 0.5 * (1 - [sin(lat1), cos(lat1)cos(lon1), cos(lat1)sin(lon1)] * [sin(lat2), cos(lat2)cos(lon2),
	 * cos(lat2)sin(lon2)])
	 * <p>
	 * So we can pre-calculate the sin(latitude), cos(latitude) * cos(longitude), cos(latitude) * sin(longitude) of
	 * the point, and when we need to calculate the distance with another point, only dot product is calculated.
	 * <p>
	 * For FastDistanceVectorData, the label is a three-dimension vector. For FastDistanceMatrixData, the label is a
	 * 3 X n DenseMatrix, n is the number of vectors saved in the matrix.
	 * <p>
	 * Any other kind of FastDistanceData is left untouched.
	 *
	 * @param data FastDistanceData.
	 */
	@Override
	public void updateLabel(FastDistanceData data) {
		if (data instanceof FastDistanceVectorData) {
			FastDistanceVectorData vectorData = (FastDistanceVectorData) data;
			Vector vec = vectorData.getVector();
			AkPreconditions.checkArgument(vec.size() == VECTOR_SIZE, VECTOR_SIZE_MESSAGE);
			// The label is always replaced by a freshly computed vector, so no pre-allocation is needed.
			vectorData.label = new DenseVector(vectorLabel(vec.get(0), vec.get(1)));
		} else if (data instanceof FastDistanceMatrixData) {
			FastDistanceMatrixData matrix = (FastDistanceMatrixData) data;
			// Validate the row count itself: an "even length" check would also accept e.g. 4 x n or 1 x 2k
			// matrices and then overrun the label array or silently produce meaningless labels.
			AkPreconditions.checkArgument(matrix.vectors.numRows() == VECTOR_SIZE, VECTOR_SIZE_MESSAGE);

			int pointCount = matrix.vectors.numCols();
			boolean labelReusable = matrix.label != null
				&& matrix.label.numRows() == LABEL_SIZE
				&& matrix.label.numCols() == pointCount;
			if (!labelReusable) {
				matrix.label = new DenseMatrix(LABEL_SIZE, pointCount);
			}

			// Each consecutive pair in the backing array is one (latitude, longitude) point and each
			// consecutive triple in the label array is the label of the point with the same index.
			double[] points = matrix.vectors.getData();
			double[] labels = matrix.label.getData();
			for (int col = 0; col < pointCount; col++) {
				int pointOffset = col * VECTOR_SIZE;
				double[] label = vectorLabel(points[pointOffset], points[pointOffset + 1]);
				System.arraycopy(label, 0, labels, col * LABEL_SIZE, LABEL_SIZE);
			}
		}
	}

	/**
	 * h = 0.5 * (1 - [sin(lat1), cos(lat1)cos(lon1), cos(lat1)sin(lon1)] * [sin(lat2), cos(lat2)cos(lon2),
	 * cos(lat2)sin(lon2)])
	 * <p>
	 * d = 2 * EARTH_RADIUS * arcsin(sqrt(h))
	 *
	 * @param left  single vector with label(sin(lat), cos(lat)cos(lon), cos(lat)sin(lon))
	 * @param right single vector with label(sin(lat), cos(lat)cos(lon), cos(lat)sin(lon))
	 * @return the distance
	 */
	@Override
	double calc(FastDistanceVectorData left, FastDistanceVectorData right) {
		double h = 0.5 * (1 - BLAS.dot(left.label, right.label));
		return cal(h);
	}

	/**
	 * h = 0.5 * (1 - [sin(lat1), cos(lat1)cos(lon1), cos(lat1)sin(lon1)] * [sin(lat2), cos(lat2)cos(lon2),
	 * cos(lat2)sin(lon2)])
	 * <p>
	 * d = 2 * EARTH_RADIUS * arcsin(sqrt(h))
	 *
	 * @param leftVector   single vector with label
	 * @param rightVectors vectors with labels
	 * @param res          the distances between leftVector and all the vectors in rightVectors.
	 */
	@Override
	void calc(FastDistanceVectorData leftVector, FastDistanceMatrixData rightVectors, double[] res) {
		// res = -0.5 * label^T * leftLabel; the DenseVector is built directly from res so that the
		// product is expected to land in the caller's array (same construction as before).
		BLAS.gemv(-0.5, rightVectors.label, true, leftVector.label, 0.0, new DenseVector(res));
		for (int i = 0; i < res.length; i++) {
			// 0.5 + (-0.5 * dot) == 0.5 * (1 - dot) == h
			res[i] = cal(0.5 + res[i]);
		}
	}

	/**
	 * h = 0.5 * (1 - [sin(lat1), cos(lat1)cos(lon1), cos(lat1)sin(lon1)] * [sin(lat2), cos(lat2)cos(lon2),
	 * cos(lat2)sin(lon2)])
	 * <p>
	 * d = 2 * EARTH_RADIUS * arcsin(sqrt(h))
	 *
	 * @param left  vectors with labels
	 * @param right vectors with labels
	 * @param res   the distances between all the vectors in left and all the vectors in right.
	 */
	@Override
	void calc(FastDistanceMatrixData left, FastDistanceMatrixData right, DenseMatrix res) {
		// res = -0.5 * rightLabel^T * leftLabel, then every entry is turned into a distance in place.
		BLAS.gemm(-0.5, right.label, true, left.label, false, 0.0, res);
		double[] data = res.getData();
		for (int i = 0; i < data.length; i++) {
			// 0.5 + (-0.5 * dot) == 0.5 * (1 - dot) == h
			data[i] = cal(0.5 + data[i]);
		}
	}

	/**
	 * Sparse data is not supported by this distance; always throws.
	 */
	@Override
	void calc(FastDistanceVectorData left, FastDistanceSparseData right, double[] res) {
		throw new AkUnsupportedOperationException(SPARSE_NOT_SUPPORTED_MESSAGE);
	}

	/**
	 * Sparse data is not supported by this distance; always throws.
	 */
	@Override
	void calc(FastDistanceSparseData left, FastDistanceSparseData right, double[] res) {
		throw new AkUnsupportedOperationException(SPARSE_NOT_SUPPORTED_MESSAGE);
	}
}
